#!/usr/bin/env python

'''Tests the driver API for making connections and exercises the networking code'''

from __future__ import print_function

import datetime, os, random, re, socket, sys, tempfile, time

sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), os.pardir, os.pardir, "common"))
import driver, rdb_unittest, utils

# -- import the rethinkdb driver

# `r` is the driver module under test, as returned by the shared test utilities
r = utils.import_python_driver()

# -- get settings

# path of the server binary: the first command-line argument if given, otherwise whatever utils locates
rethinkdb_exe = sys.argv[1] if len(sys.argv) > 1 else utils.find_rethinkdb_executable()

# -- shared server

# `sharedServer` is created on demand by TestWithConnection.setUp and reused by every test class that
# does not ask for a private server. When RDB_SERVER_HOST and/or RDB_DRIVER_PORT is set in the
# environment, those classes connect to that external server instead of starting one.
sharedServer       = None
externalServerHost = os.environ.get('RDB_SERVER_HOST', None)
externalServerPort = int(os.environ['RDB_DRIVER_PORT']) if 'RDB_DRIVER_PORT' in os.environ else None

# These flags decide which tests TestConnectionDefaults defines:
#   USE_DEFAULT_PORT -- the port taken from RDB_DRIVER_PORT equals the driver's default port
#   USE_DEFAULT_HOST -- RDB_SERVER_HOST is unset or is 'localhost'
USE_DEFAULT_PORT   = externalServerPort == r.DEFAULT_PORT
USE_DEFAULT_HOST   = externalServerHost in ('localhost', None)

# == Test Base Classes

class TestWithConnection(rdb_unittest.TestCaseCompatible):
    '''Base class for tests that need a running server and an open connection.

    `setUp` selects one of three servers:
      * a private server owned by the test class, when `private_server`, `use_tls`, or `admin_pass` is set
      * the module-wide `sharedServer`, when neither RDB_SERVER_HOST nor RDB_DRIVER_PORT is configured
      * the external server described by those environment variables otherwise

    The selected server, its host and port, and the connection are stored on the test class so that
    consecutive tests of the same class reuse them. Before reuse the server is probed and the
    connection pinged; anything that fails is discarded and recreated.
    '''

    server  = None
    conn    = None
    host    = None
    port    = None

    admin_pass     = None
    private_server = False
    use_tls        = False

    @staticmethod
    def _ssl_option(server):
        '''Return the `ssl` argument for `r.connect`: the CA certificate of `server` if it has one, else None.'''
        if server and server.tlsCertPath:
            return {'ca_certs':server.tlsCertPath}
        return None

    def _check_server(self, server):
        '''Raise unless `server` passes its own check and answers a trivial query.

        The probe connection is always closed again so repeated validation does not pile up sockets.
        '''
        server.check()
        probe = r.connect(host=server.host, port=server.driver_port, ssl=self._ssl_option(server), password=self.admin_pass or '')
        try:
            r.expr(1).run(probe)
        finally:
            probe.close(noreply_wait=False)

    @classmethod
    def _drop_cached_connection(cls):
        '''Close the cached connection on a best-effort basis and forget it.'''
        try:
            cls.conn.close()
        except Exception:
            pass # there may be no connection, or its server may already be gone
        cls.conn = None

    def setUp(self):
        '''Make sure `self.server`, `self.host`, `self.port`, and `self.conn` are usable and the `test` db is empty.'''
        global sharedServer
        cls = self.__class__

        if self.private_server or self.use_tls or self.admin_pass:
            # - a cached shared server (possibly inherited from a parent class) does not qualify
            if self.server and self.server == sharedServer:
                cls.server = None
                # the cached connection points at the shared server; it is not closed
                # because it may be inherited from a parent class that still uses it
                cls.conn = None

            # - validate running server
            if self.server:
                try:
                    self._check_server(self.server)
                except Exception:
                    # ToDo: figure out how to blame the last test
                    try:
                        self.server.stop()
                    except Exception as e:
                        sys.stderr.write('Got error while shutting down private server: %s' % str(e))
                    self._drop_cached_connection()
                    cls.server = None

            # - setup private server
            if self.server is None:
                sys.stdout.write('new private server... ')
                sys.stdout.flush()
                cls.server = driver.Process(executable_path=rethinkdb_exe, console_output=cls.__name__ + '_server_console.txt', wait_until_ready=True, tls=self.use_tls)

                if self.admin_pass is not None:
                    # a fresh server accepts a connection without a password; use one to set it
                    setupConn = r.connect(host=self.server.host, port=self.server.driver_port, ssl=self._ssl_option(self.server))
                    try:
                        result = r.db('rethinkdb').table('users').get('admin').update({'password':self.admin_pass}).run(setupConn)
                    finally:
                        setupConn.close()
                    if result != {'skipped': 0, 'deleted': 0, 'unchanged': 0, 'errors': 0, 'replaced': 1, 'inserted': 0}:
                        cls.server = None
                        raise Exception('Unable to set admin password, got: %s' % str(result))

            cls.host = self.server.host
            cls.port = self.server.driver_port

        elif not any([externalServerHost, externalServerPort]):
            # - a cached server that is not the current shared server does not qualify
            if self.server and self.server != sharedServer:
                cls.server = None
                cls.conn = None # see the matching note in the private-server branch

            # - validate running server
            if sharedServer is not None:
                try:
                    self._check_server(sharedServer)
                except Exception:
                    # ToDo: figure out how to blame the last test
                    try:
                        sharedServer.stop()
                    except Exception as e:
                        sys.stderr.write('Got error while shutting down shared server: %s' % str(e))
                    self._drop_cached_connection()
                    sharedServer = None
                    cls.server = None

            # - setup shared server
            if sharedServer is None:
                sys.stdout.write('new shared server... ')
                sys.stdout.flush()
                sharedServer = driver.Process(executable_path=rethinkdb_exe, console_output=False, wait_until_ready=True, tls=self.use_tls)

            cls.server = sharedServer
            cls.host   = self.server.host
            cls.port   = self.server.driver_port

        else:
            # - use external server

            cls.server = None
            cls.host   = externalServerHost or 'localhost'
            cls.port   = externalServerPort or r.DEFAULT_PORT

        # - establish a connection

        if self.conn:
            # reuse the cached connection only if it still answers
            try:
                r.expr(1).run(self.conn)
            except Exception:
                cls.conn = None
        if not self.conn:
            cls.conn = r.connect(host=self.host, port=self.port, ssl=self._ssl_option(self.server), password=self.admin_pass or '')

        # - ensure `test` db exists and is empty

        r.expr(['test']).set_difference(r.db_list()).for_each(r.db_create(r.row)).run(self.conn)
        r.db('test').table_list().for_each(r.db('test').table_drop(r.row)).run(self.conn)

# == Test Classes

class TestNoConnection(rdb_unittest.TestCaseCompatible):
    '''Error reporting for connection attempts that cannot succeed; no server is started.'''

    def test_connect_port(self):
        # default host, port handed out by utils.get_avalible_port()
        targetPort = utils.get_avalible_port()
        expectedMessage = "Could not connect to localhost:%d." % targetPort
        self.assertRaisesRegexp(r.ReqlDriverError, expectedMessage, r.connect, port=targetPort)

    def test_connect_timeout(self):
        '''Test that we get a ReQL error if we connect to a non-responsive port'''
        # a socket that listens but never accepts stands in for the unresponsive peer
        silentSocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        silentSocket.bind(('localhost', 0))
        silentSocket.listen(0)

        boundHost, boundPort = silentSocket.getsockname()
        expectedMessage = "Could not connect to %s:%d, operation timed out." % (boundHost, boundPort)

        try:
            self.assertRaisesRegexp(r.ReqlTimeoutError, expectedMessage, r.connect, host=boundHost, port=boundPort, timeout=2)
        finally:
            silentSocket.close()

    def test_connect_host(self):
        # explicit host "0.0.0.0", port handed out by utils.get_avalible_port()
        targetPort = utils.get_avalible_port()
        expectedMessage = "Could not connect to 0.0.0.0:%d." % targetPort
        self.assertRaisesRegexp(r.ReqlDriverError, expectedMessage, r.connect, host="0.0.0.0", port=targetPort)

    def test_empty_run(self):
        '''Test the error message when we pass nothing to run and didn't call `repl`'''
        query = r.expr(1)
        self.assertRaisesRegexp(r.ReqlDriverError, "RqlQuery.run must be given a connection to run on.", query.run)


class TestConnectionDefaults(TestWithConnection):
    '''Connecting while relying on the driver's default host and/or port.

    Each group of tests is only defined when the matching module flag (USE_DEFAULT_PORT,
    USE_DEFAULT_HOST) is set.
    '''

    incorrectAuthMessage = "Could not connect to %s:%d: Wrong password"

    def _assert_connects(self, **connectArgs):
        '''Connect with `connectArgs`, check that `reconnect` works, and always close the connection.'''
        conn = r.connect(**connectArgs)
        try:
            conn.reconnect()
        finally:
            conn.close()

    def _assert_wrong_password(self, **connectArgs):
        '''Check that connecting with `connectArgs` plus a bogus password raises ReqlAuthError.'''
        self.assertRaisesRegexp(r.ReqlAuthError, self.incorrectAuthMessage % (self.host, self.port), r.connect, password="hunter2", **connectArgs)

    if USE_DEFAULT_PORT:
        def test_connect_default_port(self):
            self._assert_connects(host=self.host)

        def test_connect_default_port_wrong_auth(self):
            self._assert_wrong_password(host=self.host)

    if USE_DEFAULT_HOST:
        def test_connect_default_host(self):
            self._assert_connects(port=self.port)

        def test_connect_default_host_wrong_auth(self):
            self._assert_wrong_password(port=self.port)

    if USE_DEFAULT_PORT and USE_DEFAULT_HOST:
        def test_connect_default_host_port(self):
            self._assert_connects()

        def test_connect_wrong_auth(self):
            self._assert_wrong_password()


class TestAuthConnection(TestWithConnection):
    '''Connecting to a server whose admin account has a password.

    Setting `admin_pass` makes `TestWithConnection.setUp` use a private server with that password.
    '''

    incorrectAuthMessage = "Could not connect to %s:%d: Wrong password"
    admin_pass = 'hunter2'

    def _assert_rejected(self, **credentials):
        '''Check that connecting to the test server with `credentials` raises ReqlAuthError.'''
        self.assertRaisesRegexp(r.ReqlAuthError, self.incorrectAuthMessage % (self.host, self.port), r.connect, host=self.host, port=self.port, **credentials)

    def test_connect_no_auth(self):
        self._assert_rejected()

    def test_connect_wrong_auth(self):
        # empty, nearly right, and right-prefix keys must all be refused
        for wrongKey in ("", "hunter3", "hunter22"):
            self._assert_rejected(auth_key=wrongKey)

    def test_connect_correct_auth(self):
        # connect to the same host the other tests use rather than the driver default
        conn = r.connect(host=self.host, port=self.port, auth_key="hunter2")
        try:
            conn.reconnect()
        finally:
            conn.close()


class TestConnection(TestWithConnection):
    '''Behaviour of an open connection: close and reconnect, waiting for noreply queries, default db, repl, and server shutdown.'''

    def _open_connection(self, server=None, **connectArgs):
        '''Open a new connection using the same password and TLS handling as `setUp`.

        server      -- connect to this server object (its `host`, `driver_port`, and `tlsCertPath` are used)
                       instead of the class's cached host/port/server
        connectArgs -- extra keyword arguments passed through to `r.connect`, e.g. `db`
        '''
        if server is None:
            host, port, certSource = self.host, self.port, self.server
        else:
            host, port, certSource = server.host, server.driver_port, server
        sslOption = {'ca_certs':certSource.tlsCertPath} if certSource and certSource.tlsCertPath else None
        return r.connect(host=host, port=port, ssl=sslOption, password=self.admin_pass or '', **connectArgs)

    def _start_slow_noreply_query(self):
        '''Send an endless JS loop with a 0.5 second `timeout` as a noreply query; return the time just before sending.'''
        started = time.time()
        r.js('while(true);', timeout=0.5).run(self.conn, noreply=True)
        return started

    def test_client_port_and_address(self):
        self.assertIsNotNone(self.conn.client_port())
        self.assertIsNotNone(self.conn.client_address())

        self.conn.close()

        self.assertIsNone(self.conn.client_port())
        self.assertIsNone(self.conn.client_address())

    def test_connect_close_reconnect(self):
        r.expr(1).run(self.conn)
        self.conn.close()
        self.conn.close() # closing twice must not raise
        self.conn.reconnect()
        r.expr(1).run(self.conn)

    def test_connect_close_expr(self):
        r.expr(1).run(self.conn)
        self.conn.close()
        self.assertRaisesRegexp(r.ReqlDriverError, "Connection is closed.", r.expr(1).run, self.conn)

    def test_noreply_wait_waits(self):
        started = self._start_slow_noreply_query()
        self.conn.noreply_wait()
        duration = time.time() - started
        self.assertGreaterEqual(duration, 0.5)

    def test_close_waits_by_default(self):
        started = self._start_slow_noreply_query()
        self.conn.close()
        duration = time.time() - started
        self.assertGreaterEqual(duration, 0.5)

    def test_reconnect_waits_by_default(self):
        started = self._start_slow_noreply_query()
        self.conn.reconnect()
        duration = time.time() - started
        self.assertGreaterEqual(duration, 0.5)

    def test_close_does_not_wait_if_requested(self):
        started = self._start_slow_noreply_query()
        self.conn.close(noreply_wait=False)
        duration = time.time() - started
        self.assertLess(duration, 0.5)

    def test_reconnect_does_not_wait_if_requested(self):
        started = self._start_slow_noreply_query()
        self.conn.reconnect(noreply_wait=False)
        duration = time.time() - started
        self.assertLess(duration, 0.5)

    def test_db(self):
        c = self._open_connection()
        try:
            r.db('test').table_create('t1').run(c)

            # make sure `db2` exists and holds a fresh table `t2`
            r.expr(['db2']).set_difference(r.db_list()).for_each(r.db_create(r.row)).run(c)
            r.expr(['t2']).set_intersection(r.db('db2').table_list()).for_each(r.db('db2').table_drop(r.row)).run(c)
            r.db('db2').table_create('t2').run(c)

            # Default db should be 'test' so this will work
            r.table('t1').run(c)

            # Use a new database
            c.use('db2')
            r.table('t2').run(c)
            self.assertRaisesRegexp(r.ReqlRuntimeError, "Table `db2.t1` does not exist.", r.table('t1').run, c)

            c.use('test')
            r.table('t1').run(c)
            self.assertRaisesRegexp(r.ReqlRuntimeError, "Table `test.t2` does not exist.", r.table('t2').run, c)
        finally:
            c.close()

        # Test setting the db in connect
        c = self._open_connection(db='db2')
        try:
            r.table('t2').run(c)

            self.assertRaisesRegexp(r.ReqlRuntimeError, "Table `db2.t1` does not exist.", r.table('t1').run, c)
        finally:
            c.close()

        # Test setting the db as a `run` option
        r.table('t2').run(self.conn, db='db2')

    def test_outdated_read(self):
        r.db('test').table_create('t1').run(self.conn)

        # Use outdated is an option that can be passed to db.table or `run`
        # We're just testing here if the server actually accepts the option.

        r.table('t1', read_mode='outdated').run(self.conn)
        r.table('t1').run(self.conn, read_mode='outdated')

    def test_repl(self):
        try:
            # Calling .repl() should set this connection as global state to be used when `run` is not otherwise passed a connection.
            c = self._open_connection()
            c.repl()
            r.expr(1).run()

            c.repl() # is idempotent
            r.expr(1).run()

            c.close()
            self.assertRaisesRegexp(r.ReqlDriverError, "Connection is closed", r.expr(1).run)
        finally:
            # never leave the global repl connection set for later tests
            r.Repl.clear()

    def test_port_conversion(self):
        self.assertRaisesRegexp(r.ReqlDriverError, "Could not convert port 'abc' to an integer.", r.connect, port='abc', host=self.host)

    def test_shutdown(self):
        # use our own server for simplicity
        with driver.Process(executable_path=rethinkdb_exe, console_output=False, wait_until_ready=True, tls=self.use_tls) as server:
            c = self._open_connection(server=server)
            r.expr(1).run(c)

            server.stop()
            time.sleep(0.2)

            self.assertRaisesRegexp(r.ReqlDriverError, "Connection is closed.", r.expr(1).run, c)

            c.close(noreply_wait=False)


class TestTLSConnection(TestConnection):
    '''Runs every test inherited from TestConnection with `use_tls` switched on.'''
    use_tls = True

class TestMisc(TestWithConnection):
    '''Driver features unrelated to connections that are tested here for lack of a better home.'''

    def test_prety_print(self):
        '''Simple pretty-printing works'''
        # Only one representative query is checked; covering the whole printer is out of scope.

        rendered = str(r.db('db1').table('tbl1').map(lambda x: x))
        # the lambda's variable number is arbitrary but must be the same in both places
        expectedPattern = r"r\.db\('db1'\)\.table\('tbl1'\)\.map\(lambda var_(\d+): var_\1\)"
        if re.match(expectedPattern, rendered) is None:
            self.fail('Printing query did not have the expected output, rather: %s' % rendered)

    def test_pickling(self):
        '''Test the queries can be pickled'''

        import pickle
        conn = self.conn

        # a real db and table are required so the query can make the round-trip to the server
        missingDbs = r.expr(['test']).set_difference(r.db_list())
        missingDbs.for_each(r.db_create(r.row)).run(conn)
        missingTables = r.expr(['test']).set_difference(r.db('test').table_list())
        missingTables.for_each(r.db('test').table_create(r.row)).run(conn)

        query = r.db('test').table('test').between(r.maxval, r.maxval).limit(2).count()
        expected = query.run(conn)
        self.assertLessEqual(expected, 2) # sanity check

        restored = pickle.loads(pickle.dumps(query))
        self.assertEqual(str(restored), str(query))
        self.assertEqual(restored.run(conn), expected)

class TestBatching(TestWithConnection):
    '''Cursor behaviour when results arrive in small batches.'''

    def test_get_intersecting_batching(self):
        '''Test that get_intersecting() batching works properly'''

        def randomLonLat():
            # longitude is drawn before latitude, once each per call
            lon = random.uniform(-180.0, 180.0)
            lat = random.uniform(-90.0, 90.0)
            return [lon, lat]

        r.db('test').table_create('t1').run(self.conn)
        table = r.db('test').table('t1')

        table.index_create('geo', geo=True).run(self.conn)
        table.index_wait('geo').run(self.conn)

        batch_size = 3
        point_count = 500
        poly_count = 500
        get_tries = 10

        # Random points give a well distributed range of secondary keys; the
        # large-ish polygons let the server's duplicate filtering be exercised.
        # The seed is printed so a failing run can be reproduced.
        rseed = random.getrandbits(64)
        random.seed(rseed)
        print("Random seed: " + str(rseed) + ' ...', end=' ')
        sys.stdout.flush()

        points = [{'geo': r.point(*randomLonLat())} for _ in range(point_count)]
        table.insert(points).run(self.conn)

        polygons = [{'geo': r.circle(randomLonLat(), 1000000)} for _ in range(poly_count)]
        table.insert(polygons).run(self.conn)

        # The result should be lazy (cursor not yet finished) at least once.
        # The test is randomized, but the odds of never seeing that are tiny.
        seen_lazy = False

        for _ in range(get_tries):
            query_circle = r.circle(randomLonLat(), 8000000)
            unmatched = table.filter(r.row['geo'].intersects(query_circle)).coerce_to("ARRAY").run(self.conn)
            cursor = table.get_intersecting(query_circle, index='geo').run(self.conn, max_batch_rows=batch_size)
            if cursor.error is None:
                seen_lazy = True

            # every row from the cursor must match exactly one not-yet-seen reference row
            while unmatched:
                row = cursor.next()
                self.assertEqual(unmatched.count(row), 1)
                unmatched.remove(row)
            self.assertRaises(r.ReqlCursorEmpty, cursor.next)

        self.assertTrue(seen_lazy)

    def test_batching(self):
        '''Test the cursor API when there is exactly mod batch size elements in the result stream'''

        r.db('test').table_create('t1').run(self.conn)
        table = r.table('t1')

        batch_size = 3
        count = 500

        remaining = set(range(count))

        table.insert([{'id':n} for n in remaining]).run(self.conn)
        cursor = table.run(self.conn, max_batch_rows=batch_size)

        # consume all rows but one, ticking each id off the set
        for _ in range(count - 1):
            rowId = cursor.next()['id']
            self.assertTrue(rowId in remaining)
            remaining.remove(rowId)

        # the final row must be the single id left, and then the cursor is exhausted
        self.assertEqual(cursor.next()['id'], remaining.pop())
        self.assertRaises(r.ReqlCursorEmpty, cursor.next)


class TestGroupWithTimeKey(TestWithConnection):
    '''Grouping a table by a time field and comparing the group keys against Python datetimes.'''

    def runTest(self):
        r.db('test').table_create('times').run(self.conn)

        # for each epoch: the ReQL time to store and the datetime expected back, both at +00:00
        epochs = (1375115782.24, 1375147296.68)
        samples = [
            (r.epoch_time(epoch).in_timezone('+00:00'), datetime.datetime.fromtimestamp(epoch, r.ast.RqlTzinfo('+00:00')))
            for epoch in epochs
        ]

        for rowId, (reqlTime, _) in enumerate(samples):
            res = r.table('times').insert({'id':rowId, 'time':reqlTime}).run(self.conn)
            self.assertEqual(res['inserted'], 1)

        # every row has a distinct time, so each group holds exactly its own row
        expectedGroups = dict(
            (pyTime, [{'id':rowId, 'time':pyTime}])
            for rowId, (_, pyTime) in enumerate(samples)
        )

        groups = r.table('times').group('time').coerce_to('array').run(self.conn)
        self.assertEqual(groups, expectedGroups)


class TestSuccessAtomFeed(TestWithConnection):
    '''A changefeed on a single document opens without error and already holds its initial value.'''

    def runTest(self):
        r.db('test').table_create('success_atom_feed').run(self.conn)
        t1 = r.db('test').table('success_atom_feed')

        res = t1.insert({'id': 0, 'a': 16}).run(self.conn)
        self.assertEqual(res['inserted'], 1)
        res = t1.insert({'id': 1, 'a': 31}).run(self.conn)
        self.assertEqual(res['inserted'], 1)

        t1.index_create('a', lambda x: x['a']).run(self.conn)
        t1.index_wait('a').run(self.conn)

        changes = t1.get(0).changes(include_initial=True).run(self.conn)
        try:
            self.assertTrue(changes.error is None)
            self.assertEqual(len(changes.items), 1)
        finally:
            # the feed never ends by itself; close it so it does not stay open on the class-level connection
            changes.close()


# Run the tests through the project's unittest wrapper when executed as a script.
if __name__ == '__main__':
    rdb_unittest.main()
